#######################################################################################################
#### SPACEWHALE PROJECT
#### S3 File
#### Script for testing a pytorch, convolutional neural net, using the pre-trained resnet18 model  ####
#### Authors:  Hieu Le & Grant Humphries
#### Date: August 2018
#### This script was written for the Spacewhale project 
#### This script was written based on the Pytorch transfer learning tutorial: 
#### https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html 
#######################################################################################################
#### Usage examples (Linux)
####
####  $python testing_script.py --data_dir /home/ghumphries/spacewhale/test --modtype densenet --model MODEL1 --epoch 24
####
#######################################################################################################
#### Setup information
####    To run this script, ensure that you have folders named exactly the same as those in the training data folder
####    For example: 
####    ./test/Water 
####    ./test/Whale
####    IMPORTANT:
####        The images that you want to test should all live in the target folder. For example, if you only want to test for
####        water, then place all the images in the ./test/Water folder. If you want to test for whales, place all the images in 
####        the ./test/Whale folder
####        The data_dir argument should point to the directory ABOVE the class folders (e.g. Water, Whale).
####        For example, if your directory is:  /home/user/spacewhale/testingdata/Water
####        then --data_dir /home/user/spacewhale/testingdata
#######################################################################################################
### Library imports

from __future__ import print_function, division

from PIL import Image
from PIL import ImageFilter
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from m_util import *
import matplotlib.pyplot as plt
import time
import os
import copy
import argparse
from model import define_model

### Add arguments
# Command-line interface (see the usage example in the file header).
parse = argparse.ArgumentParser(
    description='Test a trained spacewhale model on a directory of images.')
# --data_dir and --model are used unconditionally later in the script, so
# fail early with a clear argparse message instead of a TypeError on None.
parse.add_argument('--data_dir', required=True,
                   help='Directory ABOVE the class folders holding the test images')
parse.add_argument('--model', required=True,
                   help='Name of the sub-folder of ./trained_model holding the checkpoints')
parse.add_argument('--modtype', type=str,
                   help='Model architecture name, passed on to define_model')
parse.add_argument('--epoch', type=int, default=24,
                   help='Epoch number of the checkpoint to load (default: 24)')
opt = parse.parse_args()

### Bring in the spacewhale class (from m_util.py)
s = spacewhale(opt)

# The checkpoint is expected at ./trained_model/<model>/epoch_<N>.pth
epoch_to_use = 'epoch_'+str(opt.epoch)+'.pth' # Which model epoch to use
trained_model = os.path.join('./trained_model',opt.model,epoch_to_use)

data_dir = opt.data_dir  # Where's the data?

test_transforms = s.data_transforms['test']
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Only switch the default tensor type to CUDA when a GPU is actually present;
# doing it unconditionally breaks the CPU fallback selected just above.
if device.type == 'cuda':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

model_ft = define_model(name = opt.modtype) #Which model type to use?
model_ft = model_ft.to(device)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
model_ft.load_state_dict(torch.load(trained_model, map_location=device))
model_ft.eval()  # inference mode


# Build the test dataset from the class sub-folders of data_dir, applying the
# 'test' transforms supplied by the spacewhale helper object.
image_datasets = ImageFolderWithPaths(data_dir, s.data_transforms['test'])

# Fixed iteration order (no shuffling), batches of 10 images, 16 worker processes.
dataloaders = torch.utils.data.DataLoader(
    image_datasets,
    batch_size=10,
    shuffle=False,
    num_workers=16,
)

# Map each class index to its class name so the printed output is readable.
class_names = image_datasets.classes
keylist = list(range(len(class_names)))
d = dict(zip(keylist, class_names))

print(epoch_to_use)
print(d)

# Hand the model and the data loader to the spacewhale test routine.
s.test_dir(device, model_ft, dataloaders)



